'''
This Python file contains several helper functions to generate the plots used in the paper.
'''

### Imports
import geopandas as gpd
import numpy as np
import matplotlib
#matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import json
import os
import pandas as pd

import os
from functools import partial
import json
import csv

import geopandas as gpd
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import networkx as nx
import pandas as pd
import numpy as np
from tqdm import tqdm_notebook
import random
import seaborn as sns

### Helper functions
def binning_seats(data):
    """Return ``max(seats) - min(seats) + 1`` over the "seats" column of ``data``.

    For integer seat counts this is the number of distinct seat totals spanned
    by the ensemble. Rows are read as ``data["seats"][row]`` for
    ``row in range(len(data))`` (presumably a pandas DataFrame with a default
    RangeIndex -- confirm against callers).
    """
    observed = [data["seats"][row] for row in range(len(data))]
    highest = max(observed)
    lowest = min(observed)
    return highest - lowest + 1


def make_full_lists(data):
    """Extract the per-plan seats, PG and MM values of ``data`` as plain lists.

    Rows are read as ``data[column][i]`` for ``i in range(len(data))``.

    Returns
    -------
    tuple
        ``(all_seats, all_pgs, all_mms)`` -- note PG comes before MM.
    """
    rows = range(len(data))
    all_seats = [data["seats"][i] for i in rows]
    all_mms = [data["mm"][i] for i in rows]
    all_pgs = [data["pg"][i] for i in rows]
    return all_seats, all_pgs, all_mms


def make_winnowed_lists(data, inliers, pg_bound, mm_bound):
    """Collect metric values of seat-count outliers and seats of low-metric plans.

    Parameters
    ----------
    data : table-like
        Per-plan records read as ``data[column][i]`` for ``i in range(len(data))``,
        using columns "seats", "mm" and "pg" (presumably a pandas DataFrame with
        a default RangeIndex -- confirm against callers).
    inliers : pair
        ``(low, high)``. Plans with seats < low are collected as D-side outliers;
        plans with seats > high as R-side outliers. Both tests are applied
        independently.
    pg_bound : float
        A plan's seat count goes into ``ideal_pg_seats`` when its PG < pg_bound.
    mm_bound : float
        A plan's seat count goes into ``ideal_mm_seats`` when |MM| < mm_bound.

    Returns
    -------
    tuple
        ``(outlying_pgs, outlying_mms, num_outliers, ideal_pg_seats, ideal_mm_seats)``
        where ``outlying_pgs = [pg_r, pg_d]`` and ``outlying_mms = [mm_r, mm_d]``
        hold the metric values of the R-side (seats > high) and D-side
        (seats < low) outliers, and ``num_outliers`` is their total count.
    """
    mm_routliers, mm_doutliers = [], []
    pg_routliers, pg_doutliers = [], []
    ideal_pg_seats, ideal_mm_seats = [], []
    low, high = inliers[0], inliers[1]

    for i in range(len(data)):
        seats = data["seats"][i]
        mm = data["mm"][i]
        pg = data["pg"][i]
        if seats < low:
            mm_doutliers.append(mm)
            pg_doutliers.append(pg)
        if seats > high:
            mm_routliers.append(mm)
            pg_routliers.append(pg)
        if pg < pg_bound:
            ideal_pg_seats.append(seats)
        if abs(mm) < mm_bound:
            ideal_mm_seats.append(seats)

    outlying_mms = [mm_routliers, mm_doutliers]
    outlying_pgs = [pg_routliers, pg_doutliers]
    num_outliers = len(mm_doutliers) + len(mm_routliers)
    return outlying_pgs, outlying_mms, num_outliers, ideal_pg_seats, ideal_mm_seats

def _color_seat_bins(bars, inliers):
    """Recolor seat-count histogram bars that fall outside the inlier range.

    ``bars`` are the patches returned by ``Axes.hist`` with unit-wide bins
    starting at 0, so bar ``i`` counts plans with ``i`` seats. Bars below
    ``inliers[0]`` become blue, bars above ``inliers[1]`` become red.
    """
    for seat_count, bar in enumerate(bars):
        if seat_count < inliers[0]:
            bar.set_fc("blue")
        if seat_count > inliers[1]:
            bar.set_fc("red")


def visualize(state, data, inliers, num_dist, pg_bound, mm_bound, state_name, election_name,
              bound_name, newdir, showPlots=True):
    """Draw and save the seat and MM/PG distribution figures for one ensemble.

    Eight PNGs are written at 1200 dpi to
    ``newdir + state_name + election_name + "_" + bound_name + <suffix>``:

    * ``_all_seats``, ``_pg_seats``, ``_mm_seats`` -- histograms of Republican
      seats won for all plans, plans with PG < ``pg_bound``, and plans with
      |MM| < ``mm_bound``; bars below ``inliers[0]`` are blue, bars above
      ``inliers[1]`` red.
    * ``_all_dists`` -- 2x2 grid: MM and PG histograms for all plans (top row)
      and for seat-count outliers (bottom row).
    * ``_mm_dist``, ``_mm_out_dist``, ``_pg_dist``, ``_pg_out_dist`` -- the same
      four histograms as separate figures, sharing x-limits per metric.

    Parameters
    ----------
    state : str
        Label used in titles. Its first letter ("U", "T", "N") together with
        ``election_name`` selects a hard-coded statewide R share for a
        reference line; "T" and "N" also add an R-favoring legend entry.
    data : table-like
        Per-plan records read as ``data[column][i]`` with at least the columns
        "seats", "mm" and "pg" (presumably a pandas DataFrame with a default
        RangeIndex -- confirm against callers).
    inliers : pair of int
        ``(low, high)`` seat counts; plans outside this range are outliers.
    num_dist : int
        Number of districts; the seat histograms cover 0..num_dist seats.
    pg_bound, mm_bound : float
        Thresholds selecting the "ideal" plans (PG < pg_bound, |MM| < mm_bound).
    state_name, election_name, bound_name, newdir : str
        Concatenated into the output paths; only "_" is inserted before
        ``bound_name``, so ``newdir`` must end with a path separator.
    showPlots : bool
        If true, call ``show()`` on the three seat figures and the 2x2 grid.
    """
    # Hard-coded statewide R vote share for the state/election pairs that get a
    # proportionality reference line; all other pairs get no line.
    if state[0] == "U" and election_name == "SEN16":
        v_bar = 0.7155
        use_Prop = True
    elif state[0] == "T" and election_name == "SEN12":
        v_bar = 0.5815
        use_Prop = True
    elif state[0] == "N" and election_name == "SEN16":
        v_bar = 0.5302
        use_Prop = True
    else:
        v_bar = 0
        use_Prop = False

    # Setting up figures
    fig1, seats = plt.subplots(figsize=(16, 4))
    fig2, pg_seats = plt.subplots(figsize=(16, 4))
    fig3, mm_seats = plt.subplots(figsize=(16, 4))
    fig4, ax = plt.subplots(2, 2, figsize=(16, 8), sharex="col", sharey="row")
    figa, mm_dist = plt.subplots(figsize=(8, 4))
    figb, mm_out_dist = plt.subplots(figsize=(8, 4))
    figc, pg_dist = plt.subplots(figsize=(8, 4))
    figd, pg_out_dist = plt.subplots(figsize=(8, 4))
    all_figs = [fig1, fig2, fig3, fig4, figa, figb, figc, figd]

    seats_figs = [seats, pg_seats, mm_seats]

    # The spacing is meant for the 2x2 grid. plt.subplots_adjust would act on
    # the *current* figure, which at this point is figd.
    fig4.subplots_adjust(wspace=0.6, hspace=0.6)

    # Colors: order matches outlying_* = [R-side outliers, D-side outliers]
    party_colors = ["red", "blue"]
    neutral = "lightgray"

    bins_num = 20
    num_steps = len(data)

    # Formatting data
    all_seats, all_pgs, all_mms = make_full_lists(data)
    outlying_pgs, outlying_mms, num_outliers, ideal_pg_seats, ideal_mm_seats = \
        make_winnowed_lists(data, inliers, pg_bound, mm_bound)
    pct_outliers = round(num_outliers / num_steps * 100, 1)

    # Proxy legend handles, one independent list per seat figure.
    custom_lines = [[Line2D([0], [0], color="blue", lw=4)] for _ in range(3)]
    outlier_labels = ["D-favoring plans"]
    if state[0] == 'T' or state[0] == 'N':
        for handles in custom_lines:
            handles.append(Line2D([0], [0], color="red", lw=4))
        outlier_labels.append("R-favoring plans")
    if use_Prop:
        for handles in custom_lines:
            handles.append(Line2D([0], [0], color="darkgreen", lw=2, ls='solid'))
        # round() keeps float noise from the multiplication (e.g. 71.55000000000001)
        # out of the legend text.
        outlier_labels.append(str(round(v_bar * 100, 2)) + "% statewide R share")

    # Titles
    all_plans_title = state + ": " + str(num_steps) + " ReCom plans"
    outlying_seats_title = state + ": " + str(num_outliers) + " outlier plans " + "(" + str(pct_outliers) + "%)"

    # Seat histograms: unit-wide bins [0,1), ..., [num_dist, num_dist+1], i.e.
    # num_dist + 1 bars. Every bar is recolored, including the last one
    # (num_dist seats).
    seat_bins = range(num_dist + 2)
    for axis, seat_values in zip(seats_figs, [all_seats, ideal_pg_seats, ideal_mm_seats]):
        _, _, bars = axis.hist(seat_values, bins=seat_bins, color=neutral)
        _color_seat_bins(bars, inliers)

    ax[0, 0].hist(all_mms, bins=2 * bins_num, color=neutral)
    ax[0, 1].hist(all_pgs, bins=2 * bins_num, color=neutral)
    ax[1, 0].hist(outlying_mms, bins=bins_num, color=party_colors)
    ax[1, 1].hist(outlying_pgs, bins=bins_num, color=party_colors)

    # Aesthetics for seat figures: ticks centred on each bar, optional
    # proportionality line, thin separators between seat counts.
    for axis in seats_figs:
        axis.set_xlim(0, num_dist + 1)
        axis.set_xticks([x + .5 for x in range(num_dist + 1)])
        axis.set_xticklabels(range(num_dist + 1))
        if use_Prop:
            axis.axvline(x=(num_dist * v_bar) + 0.5, color="darkgreen", linestyle="solid", linewidth=2)
        for j in range(-2, num_dist + 2):
            axis.axvline(x=j, color="gray", linewidth=0.5)
        axis.grid(axis='y')

    for axis, handles in zip(seats_figs, custom_lines):
        leg1 = axis.legend(handles, outlier_labels, prop={"size": 16}, loc="upper left")
        axis.add_artist(leg1)

    for i in [0, 1]:
        for j in [0, 1]:
            ax[i, j].grid()
            ax[i, j].set_ylabel("Frequency")
            ax[i, j].tick_params(labelleft=True, labelbottom=True)
            if j == 0:
                ax[i, j].set_xlabel("Mean-Median Score")
                ax[i, j].axvline(x=0, color="black", linewidth=2, linestyle="dashed", label="MM = 0 (ideal)")
            else:
                ax[i, j].set_xlabel("Partisan Gini Score")
                ax[i, j].axvline(x=0, color="black", linewidth=2, linestyle="dashed", label="PG = 0 (ideal)")
            if i == 0:
                ax[i, j].set_title(all_plans_title)
            else:
                ax[i, j].set_title(outlying_seats_title)
            ax[i, j].legend()

    # Stand-alone distribution plots, one metric at a time.
    for all_ax, out_ax, all_vals, out_vals, label in (
        (mm_dist, mm_out_dist, all_mms, outlying_mms, "MM = 0 (ideal)"),
        (pg_dist, pg_out_dist, all_pgs, outlying_pgs, "PG = 0 (ideal)"),
    ):
        all_ax.hist(all_vals, bins=2 * bins_num, color=neutral)
        out_ax.hist(out_vals, bins=bins_num, color=party_colors)
        all_ax.axvline(x=0, color="black", linewidth=2, linestyle="dashed", label=label)
        out_ax.axvline(x=0, color="black", linewidth=2, linestyle="dashed", label=label)
        # Share one x-range so the all-plans and outlier views are comparable.
        l1, r1 = all_ax.get_xlim()
        l2, r2 = out_ax.get_xlim()
        all_ax.set_xlim(min(l1, l2), max(r1, r2))
        out_ax.set_xlim(min(l1, l2), max(r1, r2))

    for axis in (mm_dist, pg_dist, mm_out_dist, pg_out_dist):
        axis.grid()
        axis.tick_params(labelleft=True, labelbottom=True)
        axis.legend(prop={"size": 16})

    # Save all figures; the 2x2 grid is saved without bbox_inches='tight'.
    prefix = newdir + state_name + election_name + "_" + bound_name
    tight = {"bbox_inches": "tight"}
    outputs = [
        (fig1, "_all_seats.png", tight),
        (fig2, "_pg_seats.png", tight),
        (fig3, "_mm_seats.png", tight),
        (fig4, "_all_dists.png", {}),
        (figa, "_mm_dist.png", tight),
        (figb, "_mm_out_dist.png", tight),
        (figc, "_pg_dist.png", tight),
        (figd, "_pg_out_dist.png", tight),
    ]
    for fig, suffix, extra in outputs:
        fig.savefig(prefix + suffix, dpi=1200, **extra)

    if showPlots:
        for fig in (fig1, fig2, fig3, fig4):
            fig.show()

    # Close every figure created here so repeated calls do not accumulate
    # open figures in pyplot's figure manager.
    for fig in all_figs:
        plt.close(fig)

def runElection(v, UPS=0):
    """Apply a uniform swing ``UPS`` to every share in ``v`` (Seats-Votes helper).

    Returns
    -------
    tuple
        ``(votes, seats)``: the mean of the shifted shares, and the fraction of
        shifted shares strictly greater than 0.5.
    """
    swung = [share + UPS for share in v]
    won = 0
    for share in swung:
        if share > 0.5:
            won += 1
    return np.mean(swung), won / len(swung)

def plot_sv(v, with_mirror=False, out_dir="../plots/High_Res_Plots"):
    """Plot and save the Seats-Votes curve generated by vote-share vector ``v``.

    The curve is traced by calling ``runElection(v, UPS)`` for uniform swings
    ``UPS`` in ``np.arange(-1, 1, 0.001)`` and plotting (votes, seats) in red.

    Parameters
    ----------
    v : sequence of float
        Vote shares. ``str(v)`` is embedded in the output filename.
    with_mirror : bool
        If true, also plot the curve reflected through (0.5, 0.5) in blue and
        save to ``{out_dir}/{v}_mirrored.png``. Otherwise mark the unswung
        result of ``runElection(v)`` with a black point and (0.5, 0.5) with a
        star, and save to ``{out_dir}/{v}.png``.
    out_dir : str
        Directory for the PNG (saved at 1200 dpi, tight bounding box). The
        default is the previously hard-coded location.
    """
    # NOTE(review): the repr of ``v`` becomes the filename; long vectors give
    # very long names -- consider passing an explicit name instead.
    curve = [runElection(v, UPS) for UPS in np.arange(-1, 1, 0.001)]
    votes = [point[0] for point in curve]
    seats = [point[1] for point in curve]

    fig, ax = plt.subplots(figsize=(8, 8))
    ax.plot(votes, seats, color="red", lw=4, zorder=1)
    ax.set_xlim(0, 1)
    ax.set_xlabel("V", fontsize=20)
    ax.set_ylim(0, 1)
    ax.set_ylabel("S", fontsize=20).set_rotation(0)
    ax.grid()

    if with_mirror:
        # Reflection of the curve through (0.5, 0.5).
        ax.plot([1 - i for i in votes],
                [1 - i for i in seats],
                color="blue",
                lw=4,
                zorder=1)
        fig.savefig(f"{out_dir}/{v}_mirrored.png", dpi=1200, bbox_inches='tight')
    else:
        # Unswung result of v, plus the symmetric point (0.5, 0.5).
        x, y = runElection(v)
        ax.scatter(x, y, color="black", lw=5, zorder=2)
        ax.scatter(0.5, 0.5, color="black", marker="*", s=200, zorder=2)
        fig.savefig(f"{out_dir}/{v}.png", dpi=1200, bbox_inches='tight')

def discrep(v, i):
    """Return the i-th (1-indexed) discrepancy term used by the PG score.

    Averages ``v[i-1]`` and ``v[k-i]`` (the i-th entry from each end of ``v``)
    and subtracts the mean of ``v``.
    """
    k = len(v)
    paired_average = 0.5 * (v[i - 1] + v[k - i])
    return paired_average - np.mean(v)

def PG(v):
    """Calculate the partisan Gini (PG) score of a sorted vote-share vector ``v``.

    PG = (2 / k) * sum_{i=1..k} | 0.5 * (v[i-1] + v[k-i]) - mean(v) |

    Parameters
    ----------
    v : sequence of float
        Vote shares sorted in ascending order. Lists, tuples and 1-D numpy
        arrays are accepted.

    Raises
    ------
    ValueError
        If ``v`` is not sorted in ascending order.
    ZeroDivisionError
        If ``v`` is empty.
    """
    # Compare as lists: a tuple never equals the list returned by sorted(), and
    # a numpy array would compare elementwise and make the ``if`` ambiguous.
    if list(v) != sorted(v):
        raise ValueError("Error: `v` must be a sorted vector")
    k = len(v)
    scale = 2 / k
    # Mean computed once; recomputing it per term (as discrep() does) is O(k^2).
    v_bar = np.mean(v)
    s = sum(abs(0.5 * (v[i - 1] + v[k - i]) - v_bar) for i in range(1, k + 1))
    return scale * s

def MM(v):
    """Calculate the Mean-Median score of a vote-share vector ``v``.

    Computed as the 50th percentile of ``v`` (via ``np.percentile``) minus the
    mean of ``v``.
    """
    median_share = np.percentile(v, 50)
    mean_share = np.mean(v)
    return median_share - mean_share